{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recommendations in Keras using triplet loss\n",
    "This follows the approach of Bayesian Personalized Ranking (BPR) [1].\n",
    "\n",
    "[1] Rendle, Steffen, et al. \"BPR: Bayesian personalized ranking from implicit feedback.\" Proceedings of the Twenty-Fifth Conference on Uncertainty in Artificial Intelligence. AUAI Press, 2009.\n",
    "\n",
    "This is implemented (more efficiently) in LightFM (https://github.com/lyst/lightfm). See the MovieLens example (https://github.com/lyst/lightfm/blob/master/examples/movielens/example.ipynb) for results comparable to this notebook.\n",
    "\n",
    "## Set up the architecture\n",
    "We use a simple embedding layer for both users and items. An embedding lookup is exactly equivalent to multiplying one-hot encoded user and item indices by a dense latent factor matrix, so the embedding weights are the latent factors. There are three inputs: users, positive items, and negative items. The triplet objective tries to make the positive item rank higher than the negative item for that user.\n",
    "\n",
    "Because we want a single embedding per item, the positive and negative item inputs share the same embedding weights (a siamese architecture).\n",
    "\n",
    "This is all very simple, but it could be made arbitrarily complex with more layers, convolutional layers, and so on. I expect we'll be seeing a lot of papers doing just that.\n"
   ]
  },
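  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The equivalence between an embedding lookup and a latent factor matrix multiplied by one-hot indices can be checked with a small NumPy sketch. The sizes, the random factor matrix, and the names `item_factors`, `item_ids`, `n_items` and `n_factors` are illustrative assumptions, not part of the model below.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical sizes, chosen only for illustration.\n",
    "n_items, n_factors = 5, 3\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Latent factor matrix: one row of n_factors values per item.\n",
    "item_factors = rng.normal(size=(n_items, n_factors))\n",
    "\n",
    "item_ids = np.array([2, 0, 4])\n",
    "\n",
    "# Embedding-style lookup: select rows by integer index.\n",
    "lookup = item_factors[item_ids]\n",
    "\n",
    "# Dense-layer view: one-hot indicator vectors times the same matrix.\n",
    "one_hot = np.eye(n_items)[item_ids]\n",
    "dense = one_hot.dot(item_factors)\n",
    "\n",
    "print(np.allclose(lookup, dense))  # True\n",
    "```\n",
    "\n",
    "Both views give identical vectors; the lookup simply avoids materializing the one-hot matrix, which is the usual reason to prefer an embedding layer over a dense layer on one-hot inputs."
   ]
  },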
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using Theano backend.\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Triplet loss network example for recommenders\n",
    "\"\"\"\n",
    "\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from keras import backend as K\n",
    "from keras.models import Model\n",
    "from keras.layers import Embedding, Flatten, Input, merge\n",
    "from keras.optimizers import Adam\n",
    "\n",
    "import data\n",
    "import metrics\n",
    "\n",
    "\n",
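    "def identity_loss(y_true, y_pred):\n",
    "    # The model's output is already the per-triplet loss, so just average it.\n",
    "    # y_true is a dummy target; the 0 * y_true term uses it without\n",
    "    # changing the result.\n",
    "    return K.mean(y_pred - 0 * y_true)\n",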
    "\n",
    "\n",
    "def bpr_triplet_loss(X):\n",
    "\n",
    "    positive_item_latent, negative_item_latent, user_latent = X\n",
    "\n",
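    "    # BPR-style loss on x_uij = dot(u, i) - dot(u, j): 1 - sigmoid(x_uij).\n",
    "    # (The BPR paper [1] minimizes -log sigmoid(x_uij) instead; both\n",
    "    # decrease as the positive item's score rises above the negative's.)\n",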
    "    loss = 1.0 - K.sigmoid(\n",
    "        K.sum(user_latent * positive_item_latent, axis=-1, keepdims=True) -\n",
    "        K.sum(user_latent * negative_item_latent, axis=-1, keepdims=True))\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def build_model(num_users, num_items, latent_dim):\n",
    "\n",
    "    positive_item_input = Input((1, ), name='positive_item_input')\n",
    "    negative_item_input = Input((1, ), name='negative_item_input')\n",
    "\n",
    "    # Shared embedding layer for positive and negative items\n",
    "    item_embedding_layer = Embedding(\n",
    "        num_items, latent_dim, name='item_embedding', input_length=1)\n",
    "\n",
    "    user_input = Input((1, ), name='user_input')\n",
    "\n",
    "    positive_item_embedding = Flatten()(item_embedding_layer(\n",
    "        positive_item_input))\n",
    "    negative_item_embedding = Flatten()(item_embedding_layer(\n",
    "        negative_item_input))\n",
    "    user_embedding = Flatten()(Embedding(\n",
    "        num_users, latent_dim, name='user_embedding', input_length=1)(\n",
    "            user_input))\n",
    "\n",
    "    loss = merge(\n",
    "        [positive_item_embedding, negative_item_embedding, user_embedding],\n",
    "        mode=bpr_triplet_loss,\n",
    "        name='loss',\n",
    "        output_shape=(1, ))\n",
    "\n",
    "    model = Model(\n",
    "        input=[positive_item_input, negative_item_input, user_input],\n",
    "        output=loss)\n",
    "    model.compile(loss=identity_loss, optimizer=Adam())\n",
    "\n",
    "    return model"
   ]
  },
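  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`bpr_triplet_loss` minimizes 1 - sigmoid(x_uij), where x_uij is the user's score for the positive item minus their score for the negative item. The BPR criterion in [1] instead maximizes ln sigmoid(x_uij) (plus a regularization term). The NumPy sketch below computes both for a single toy triplet; the vectors and the helper names `sigmoid` and `triplet_losses` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "\n",
    "def triplet_losses(user_vec, pos_vec, neg_vec):\n",
    "    # x_uij: how much higher the user scores the positive item than the negative one.\n",
    "    x_uij = np.dot(user_vec, pos_vec) - np.dot(user_vec, neg_vec)\n",
    "    # The loss computed by bpr_triplet_loss in this notebook.\n",
    "    notebook_loss = 1.0 - sigmoid(x_uij)\n",
    "    # The BPR negative log-likelihood term from [1], without regularization.\n",
    "    bpr_loss = -np.log(sigmoid(x_uij))\n",
    "    return x_uij, notebook_loss, bpr_loss\n",
    "\n",
    "\n",
    "# Toy latent vectors (hypothetical values).\n",
    "user = np.array([1.0, 0.5])\n",
    "pos = np.array([2.0, 1.0])\n",
    "neg = np.array([0.0, 1.0])\n",
    "\n",
    "x, nb_loss, bpr_loss = triplet_losses(user, pos, neg)\n",
    "print(x)  # 2.0\n",
    "# Swapping the items mis-ranks them, and both losses grow.\n",
    "x_bad, nb_bad, bpr_bad = triplet_losses(user, neg, pos)\n",
    "```\n",
    "\n",
    "Both losses fall as x_uij grows, so both push positive items above negative ones. One difference: the gradient of 1 - sigmoid(x) vanishes for badly mis-ranked pairs (large negative x_uij), whereas the gradient of -log sigmoid(x) approaches a constant magnitude there."
   ]
  },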
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and transform data\n",
    "We load the MovieLens 100k dataset and create triplets of (user, known positive item, randomly sampled negative item).\n",
    "\n",
    "The success metric is AUC: in this case, the probability that, for a given user, a randomly chosen known positive item from the test set is ranked higher than a randomly chosen negative item."
   ]
  },
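  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`data.get_triplets` is not shown in this notebook. A plausible minimal sketch of the triplet sampling described above follows; it assumes the interactions are a dense 0/1 NumPy array (the notebook's data may well be stored as a sparse matrix instead), and `sample_triplets` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_triplets(interactions, rng):\n",
    "    # Hypothetical helper. interactions: dense (num_users, num_items) 0/1\n",
    "    # array whose nonzero entries are known positive (user, item) pairs.\n",
    "    uid, pid = np.nonzero(interactions)\n",
    "    # One negative item per positive pair, drawn uniformly from all items.\n",
    "    nid = rng.randint(interactions.shape[1], size=len(uid))\n",
    "    return uid, pid, nid\n",
    "\n",
    "\n",
    "rng = np.random.RandomState(42)\n",
    "toy_interactions = np.array([[1, 0, 1, 0],\n",
    "                             [0, 1, 0, 0],\n",
    "                             [0, 0, 1, 1]])\n",
    "uid, pid, nid = sample_triplets(toy_interactions, rng)\n",
    "print(uid)  # [0 0 1 2 2]\n",
    "print(pid)  # [0 2 1 2 3]\n",
    "```\n",
    "\n",
    "Because negatives are drawn uniformly from all items, a sampled negative can occasionally be one of the user's positives. BPR [1] samples negatives only from items the user has not interacted with, and rejecting such collisions would bring the sketch closer to that."
   ]
  },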
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "____________________________________________________________________________________________________\n",
      "Layer (type)                     Output Shape          Param #     Connected to                     \n",
      "====================================================================================================\n",
      "positive_item_input (InputLayer) (None, 1)             0                                            \n",
      "____________________________________________________________________________________________________\n",
      "negative_item_input (InputLayer) (None, 1)             0                                            \n",
      "____________________________________________________________________________________________________\n",
      "user_input (InputLayer)          (None, 1)             0                                            \n",
      "____________________________________________________________________________________________________\n",
      "item_embedding (Embedding)       (None, 1, 100)        168300      positive_item_input[0][0]        \n",
      "                                                                   negative_item_input[0][0]        \n",
      "____________________________________________________________________________________________________\n",
      "user_embedding (Embedding)       (None, 1, 100)        94400       user_input[0][0]                 \n",
      "____________________________________________________________________________________________________\n",
      "flatten_7 (Flatten)              (None, 100)           0           item_embedding[0][0]             \n",
      "____________________________________________________________________________________________________\n",
      "flatten_8 (Flatten)              (None, 100)           0           item_embedding[1][0]             \n",
      "____________________________________________________________________________________________________\n",
      "flatten_9 (Flatten)              (None, 100)           0           user_embedding[0][0]             \n",
      "____________________________________________________________________________________________________\n",
      "loss (Merge)                     (None, 1)             0           flatten_7[0][0]                  \n",
      "                                                                   flatten_8[0][0]                  \n",
      "                                                                   flatten_9[0][0]                  \n",
      "====================================================================================================\n",
      "Total params: 262700\n",
      "____________________________________________________________________________________________________\n",
      "AUC before training 0.50247407966\n"
     ]
    }
   ],
   "source": [
    "latent_dim = 100\n",
    "num_epochs = 10\n",
    "\n",
    "# Read data\n",
    "train, test = data.get_movielens_data()\n",
    "num_users, num_items = train.shape\n",
    "\n",
    "# Prepare the test triplets\n",
    "test_uid, test_pid, test_nid = data.get_triplets(test)\n",
    "\n",
    "model = build_model(num_users, num_items, latent_dim)\n",
    "\n",
    "# Print the model structure (summary() prints directly and returns None)\n",
    "model.summary()\n",
    "\n",
    "# Sanity check, should be around 0.5\n",
    "print('AUC before training %s' % metrics.full_auc(model, test))"
   ]
  },
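  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`metrics.full_auc` is likewise not shown. A minimal NumPy sketch of a per-user AUC averaged over users, following the definition given earlier, could look like this. It assumes a precomputed (num_users, num_items) score matrix and a dense 0/1 test matrix, and `user_auc` and `mean_auc` are hypothetical names. For this model, one way to get the score matrix (an assumption about how `full_auc` scores items) would be the product of the user and item embedding matrices, e.g. `model.get_layer('user_embedding').get_weights()[0]` times the transpose of the corresponding `item_embedding` weights.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def user_auc(scores, positives):\n",
    "    # scores: predicted scores for every item for one user.\n",
    "    # positives: boolean mask of that user's positive test items.\n",
    "    pos_scores = scores[positives]\n",
    "    neg_scores = scores[~positives]\n",
    "    if len(pos_scores) == 0 or len(neg_scores) == 0:\n",
    "        return None  # AUC is undefined for this user\n",
    "    # Compare every positive with every negative; ties count as half.\n",
    "    diff = pos_scores[:, None] - neg_scores[None, :]\n",
    "    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / float(diff.size)\n",
    "\n",
    "\n",
    "def mean_auc(score_matrix, test_interactions):\n",
    "    # Average the per-user AUC over users for whom it is defined.\n",
    "    aucs = []\n",
    "    for u in range(score_matrix.shape[0]):\n",
    "        auc = user_auc(score_matrix[u], test_interactions[u] > 0)\n",
    "        if auc is not None:\n",
    "            aucs.append(auc)\n",
    "    return np.mean(aucs)\n",
    "\n",
    "\n",
    "# Toy scores and test positives (hypothetical values).\n",
    "toy_scores = np.array([[0.9, 0.1, 0.8, 0.3],\n",
    "                       [0.2, 0.4, 0.6, 0.1]])\n",
    "toy_test = np.array([[1, 0, 0, 0],\n",
    "                     [0, 1, 0, 0]])\n",
    "print(mean_auc(toy_scores, toy_test))  # about 0.833\n",
    "```\n",
    "\n",
    "This sketch treats every non-test item as a negative, including items the user interacted with in the training set; one may want to exclude those from the negatives."
   ]
  },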
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model\n",
    "Train for `num_epochs` (10) epochs, sampling a fresh set of training triplets each epoch and checking the test AUC after every epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "AUC 0.905896400776\n",
      "Epoch 1\n",
      "AUC 0.908241780938\n",
      "Epoch 2\n",
      "AUC 0.909650205748\n",
      "Epoch 3\n",
      "AUC 0.910820451523\n",
      "Epoch 4\n",
      "AUC 0.912184845152\n",
      "Epoch 5\n",
      "AUC 0.912632057958\n",
      "Epoch 6\n",
      "AUC 0.91326604222\n",
      "Epoch 7\n",
      "AUC 0.913786881853\n",
      "Epoch 8\n",
      "AUC 0.914638438854\n",
      "Epoch 9\n",
      "AUC 0.915375014253\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(num_epochs):\n",
    "\n",
    "    print('Epoch %s' % epoch)\n",
    "\n",
    "    # Sample triplets from the training data\n",
    "    uid, pid, nid = data.get_triplets(train)\n",
    "\n",
    "    X = {\n",
    "        'user_input': uid,\n",
    "        'positive_item_input': pid,\n",
    "        'negative_item_input': nid\n",
    "    }\n",
    "\n",
    "    # The targets are dummies: identity_loss only uses the model's output.\n",
    "    model.fit(X,\n",
    "              np.ones(len(uid)),\n",
    "              batch_size=64,\n",
    "              nb_epoch=1,\n",
    "              verbose=0,\n",
    "              shuffle=True)\n",
    "\n",
    "    print('AUC %s' % metrics.full_auc(model, test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After 10 epochs the test AUC is about 0.915 and still creeping up slowly. Trained for longer, the model is likely to start overfitting at some point, so it would be a good idea to stop early or add some regularization."
   ]
  }
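  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal early-stopping sketch that could replace the fixed-length training loop: it keeps training while the AUC improves and stops after `patience` epochs without improvement. `train_one_epoch` and `evaluate` are hypothetical zero-argument callables (here they would wrap `model.fit` on freshly sampled triplets and `metrics.full_auc`), and the AUC sequence in the example is made up purely to exercise the logic. Ideally the AUC used for stopping would come from a held-out validation split rather than the test set.\n",
    "\n",
    "```python\n",
    "def train_with_early_stopping(train_one_epoch, evaluate, max_epochs, patience):\n",
    "    # train_one_epoch and evaluate are hypothetical zero-argument callables.\n",
    "    best_auc = float('-inf')\n",
    "    best_epoch = -1\n",
    "    history = []\n",
    "    for epoch in range(max_epochs):\n",
    "        train_one_epoch()\n",
    "        auc = evaluate()\n",
    "        history.append(auc)\n",
    "        if auc > best_auc:\n",
    "            best_auc, best_epoch = auc, epoch\n",
    "            # A real run would also snapshot the weights here (for example\n",
    "            # with model.get_weights()) and restore the best ones at the end.\n",
    "        elif epoch - best_epoch >= patience:\n",
    "            break  # no improvement for `patience` epochs\n",
    "    return best_epoch, best_auc, history\n",
    "\n",
    "\n",
    "# Made-up AUC values, only to exercise the stopping logic.\n",
    "fake_aucs = iter([0.90, 0.91, 0.92, 0.915, 0.914, 0.913, 0.93])\n",
    "best_epoch, best_auc, history = train_with_early_stopping(\n",
    "    lambda: None, lambda: next(fake_aucs), max_epochs=7, patience=2)\n",
    "print(best_epoch, best_auc)  # 2 0.92\n",
    "print(len(history))  # 5\n",
    "```"
   ]
  }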
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
